{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from db.datasets import datasets\n",
    "from config import system_configs\n",
    "import json, os\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from sklearn.metrics import average_precision_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Label file: None\n"
     ]
    }
   ],
   "source": [
    "with open('config/KPDetection.json', \"r\") as f:\n",
    "    configs = json.load(f)\n",
    "    \n",
    "split = 'valchart'\n",
    "\n",
    "configs[\"system\"][\"data_dir\"] = \"./data/EC400K_ChartLLM/clsdata(1031)/cls\"\n",
    "configs[\"system\"][\"cache_dir\"] = \"./data/cache\"\n",
    "\n",
    "configs[\"system\"][\"dataset\"] = \"Chart\"\n",
    "configs[\"system\"][\"snapshot_name\"] = \"KPDetection_320.pkl\"\n",
    "system_configs.update_config(configs[\"system\"])\n",
    "db = datasets[\"Chart\"](configs[\"db\"], split)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_pie_center(a, b, c):\n",
    "    a,b,c = np.array(a), np.array(b), np.array(c)\n",
    "    ca = c - a\n",
    "    cb = c - b\n",
    "    cosine_angle = np.dot(ca, cb) / (np.linalg.norm(ca) * np.linalg.norm(cb))\n",
    "    angle = np.arccos(cosine_angle)\n",
    "    r_square = (ca**2).sum()\n",
    "    \n",
    "    if ca[0]*cb[1]-ca[1]*cb[0] >= 0:\n",
    "        return (a[0]+b[0]+c[0])/3., (a[1]+b[1]+c[1])/3., 0.5 * angle * r_square\n",
    "    else:\n",
    "        return 2*c[0]-(a[0]+b[0]+c[0])/3., 2*c[1]-(a[1]+b[1]+c[1])/3., np.pi * r_square - 0.5 * angle * r_square\n",
    "\n",
    "def get_points(gts, preds, chartType):\n",
    "    gt_keys, gt_cens = [], []\n",
    "    area = 0\n",
    "    \n",
    "    if chartType == 'vbar_categorical':\n",
    "        for bbox in gts.tolist():\n",
    "            area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) \n",
    "            gt_keys.append((bbox[0],bbox[1], area))\n",
    "            gt_keys.append((bbox[2],bbox[3], area))\n",
    "            gt_cens.append( ( (bbox[0] + bbox[2])/2, (bbox[1] + bbox[3])/2, area ) )\n",
    "    elif chartType == 'pie':\n",
    "        for bbox in gts.tolist():\n",
    "            a, b, c = (bbox[0], bbox[1]), (bbox[2], bbox[3]), (bbox[4], bbox[5])\n",
    "            xce, yce, area = get_pie_center(a,b,c)                \n",
    "            gt_keys.append((bbox[0],bbox[1], area))\n",
    "            gt_keys.append((bbox[2],bbox[3], area))\n",
    "            gt_keys.append((bbox[4],bbox[5], area))\n",
    "            gt_cens.append((xce, yce, area))   \n",
    "    elif chartType == 'line':\n",
    "        for bbox in gts[0]:\n",
    "            detection = np.array(bbox)\n",
    "            if len(detection) <= 1: continue\n",
    "            elif len(detection)//2 % 2 == 0:\n",
    "                mid = len(detection) // 2\n",
    "                xce, yce = (detection[mid-2] + detection[mid]) / 2, (detection[mid-1] + detection[mid+1]) / 2\n",
    "            else:\n",
    "                mid = len(detection) // 2\n",
    "                xce, yce = detection[mid-1].copy(), detection[mid].copy()\n",
    "            assert len(detection) % 2 == 0\n",
    "            xs = detection[0:len(detection):2]\n",
    "            ys = detection[1:len(detection):2]\n",
    "            area = (max(max(xs) - min(xs), max(ys) - min(ys)) / len(detection) * 2) ** 2\n",
    "                \n",
    "            for x, y in zip(xs, ys):\n",
    "                gt_keys.append((x,y, area))\n",
    "            gt_cens.append((xce, yce, area))   \n",
    "\n",
    "    pred_keys, pred_cens = [], []\n",
    "    if '1' not in preds[0]: # baseline predictions\n",
    "        if chartType == 'pie':\n",
    "            pred_keys.append((preds[0][0][0], preds[0][0][1], preds[0][-1]))\n",
    "            for pred in preds:\n",
    "                pred_keys.append((pred[1][0], pred[1][1], pred[-1]))\n",
    "        elif chartType == 'line':\n",
    "            for pred in preds:\n",
    "                pred_groups.append(np.array(pred))\n",
    "        else:\n",
    "            for pred in preds:\n",
    "                pred_keys.append((pred[0],pred[1], 1.))\n",
    "                pred_keys.append((pred[2],pred[3], 1.))\n",
    "    else:  \n",
    "        for point in preds[0]['1']:\n",
    "            pred_keys.append((point[2],point[3], point[0]))\n",
    "        for point in preds[1]['1']:\n",
    "            pred_cens.append((point[2],point[3], point[0]))\n",
    "    return gt_keys, gt_cens, pred_keys, pred_cens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def OKS(gt_p, pred_p):\n",
    "    d2 = (gt_p[0] - pred_p[0]) ** 2 + (gt_p[1] - pred_p[1]) ** 2\n",
    "    k2 = 0.1\n",
    "    s2 = gt_p[2]\n",
    "    return np.exp(d2/(s2 * k2) * (-1))\n",
    "\n",
    "def computeTargetLabel(gt_ps, pred_ps, thres=0.75):\n",
    "    y_true = []\n",
    "    for pred_p in pred_ps:\n",
    "        found = False\n",
    "        for gt_p in gt_ps:\n",
    "            if OKS(gt_p, pred_p) > thres:\n",
    "                y_true.append(1)\n",
    "                found = True\n",
    "                break\n",
    "        if not found:\n",
    "            y_true.append(0)\n",
    "    return y_true\n",
    "\n",
    "#用于计算给定阈值（默认为 0.75）下，哪些 ground truth 点（gt_ps）被预测点（pred_ps）成功检测到。\n",
    "def computeDetectedGT(gt_ps, pred_ps, thres=0.75):\n",
    "    # 初始化一个空列表 y_true，用于存储每个 ground truth 点是否被成功检测到（1 表示检测成功，0 表示未检测到）。\n",
    "    y_true = []\n",
    "    for gt_p in gt_ps:\n",
    "        found = False\n",
    "        for pred_p in pred_ps:\n",
    "            if OKS(gt_p, pred_p) > thres:\n",
    "                y_true.append(1)\n",
    "                found = True\n",
    "                break\n",
    "        # 如果遍历所有预测点后，found 仍为 False，说明当前 ground truth 点未被检测到。\n",
    "        if not found:\n",
    "            y_true.append(0)\n",
    "    return y_true"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[]\n",
      "0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/50 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "ename": "IndexError",
     "evalue": "list index out of range",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mIndexError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[7], line 7\u001b[0m\n\u001b[0;32m      5\u001b[0m \u001b[39mprint\u001b[39m(max_iter)\n\u001b[0;32m      6\u001b[0m \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m tqdm(\u001b[39mrange\u001b[39m(\u001b[39m50\u001b[39m)):\n\u001b[1;32m----> 7\u001b[0m     db_ind \u001b[39m=\u001b[39m db\u001b[39m.\u001b[39;49mdb_inds[i]\n\u001b[0;32m      8\u001b[0m     image_file \u001b[39m=\u001b[39m db\u001b[39m.\u001b[39mimage_file(db_ind)\n\u001b[0;32m      9\u001b[0m     gts \u001b[39m=\u001b[39m db\u001b[39m.\u001b[39mdetections(db_ind)\n",
      "\u001b[1;31mIndexError\u001b[0m: list index out of range"
     ]
    }
   ],
   "source": [
    "mAP_keys = []\n",
    "mAP_cens = []\n",
    "print(db.db_inds)\n",
    "max_iter = len(db.db_inds)\n",
    "print(max_iter)\n",
    "for i in tqdm(range(50)):\n",
    "    db_ind = db.db_inds[i]\n",
    "    image_file = db.image_file(db_ind)\n",
    "    gts = db.detections(db_ind)\n",
    "    print(image_file.split('/')[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['c4c0627d16eb05af88880a92faec6aaa_amFyY2hpdmVzLmNvbQkxOTIuMTg1Ljk4LjE5OA==-5-0.png', 'c49285ca77f6aff6214dad492b688113_d3d3LmRhbmUuZ292LmNvCTE3MC4yMzguNjQuNzg=.xls-0-0.png', 'c4c0627d16eb05af88880a92faec6aaa_amFyY2hpdmVzLmNvbQkxOTIuMTg1Ljk4LjE5OA==-6-0.png']\n"
     ]
    }
   ],
   "source": [
    "with open('evaluation/KPDetection50000.json') as f:\n",
    "    prediction = json.load(f)\n",
    "    \n",
    "mAP_keys = []\n",
    "mAP_cens = []\n",
    "max_iter = db.db_inds.size\n",
    "print(list(prediction.keys())[:3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/3695 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "ename": "NameError",
     "evalue": "name 'chartType' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[8], line 14\u001b[0m\n\u001b[1;32m     12\u001b[0m \u001b[39mif\u001b[39;00m preds \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mor\u001b[39;00m \u001b[39mlen\u001b[39m(preds) \u001b[39m==\u001b[39m \u001b[39m0\u001b[39m: \u001b[39mcontinue\u001b[39;00m\n\u001b[1;32m     13\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mlen\u001b[39m(preds) \u001b[39m==\u001b[39m \u001b[39m3\u001b[39m \u001b[39mand\u001b[39;00m \u001b[39mlen\u001b[39m(preds[\u001b[39m2\u001b[39m]) \u001b[39m==\u001b[39m \u001b[39m0\u001b[39m: \u001b[39mcontinue\u001b[39;00m\n\u001b[0;32m---> 14\u001b[0m gt_keys, gt_cens, pred_keys, pred_cens \u001b[39m=\u001b[39m get_points(gts, preds, chartType)\n\u001b[1;32m     16\u001b[0m \u001b[39m# 计算关于关键点（keys）的评估指标\u001b[39;00m\n\u001b[1;32m     17\u001b[0m y_true_keys \u001b[39m=\u001b[39m computeTargetLabel(gt_keys, pred_keys)\n",
      "\u001b[0;31mNameError\u001b[0m: name 'chartType' is not defined"
     ]
    }
   ],
   "source": [
    "for i in tqdm(range(max_iter)):\n",
    "    db_ind = db.db_inds[i]\n",
    "    #print(db_ind)\n",
    "    image_file = db.image_file(db_ind)\n",
    "    #print(image_file)\n",
    "    gts = db.detections(db_ind)\n",
    "    #print(gts)\n",
    "    # 如果没有 ground truth 数据，则跳过当前迭代。\n",
    "    if gts is None or len(gts) == 0: continue\n",
    "    #print(image_file.split('/')[-1])\n",
    "    preds = prediction[image_file.split('/')[-1]]\n",
    "    if preds is None or len(preds) == 0: continue\n",
    "    if len(preds) == 3 and len(preds[2]) == 0: continue\n",
    "    gt_keys, gt_cens, pred_keys, pred_cens = get_points(gts, preds, chartType)\n",
    "    \n",
    "    # 计算关于关键点（keys）的评估指标\n",
    "    y_true_keys = computeTargetLabel(gt_keys, pred_keys)\n",
    "    y_score_keys= [key[2] for key in pred_keys]\n",
    "    \n",
    "    detected_gt_keys = computeDetectedGT(gt_keys, pred_keys)\n",
    "    miss_count = len(detected_gt_keys) - sum(detected_gt_keys)\n",
    "    # 对漏检的 ground truth，其真实标签应为 1。\n",
    "    y_true_keys = y_true_keys + [1] * miss_count\n",
    "    # 漏检的 ground truth 的预测得分应为 0。\n",
    "    y_score_keys = y_score_keys + [0] * miss_count\n",
    "    \n",
    "    score = average_precision_score(y_true_keys, y_score_keys)\n",
    "#     if score < 0.3:\n",
    "#         print(image_file)\n",
    "    mAP_keys = np.append(mAP_keys,score)\n",
    "    \n",
    "    # cens\n",
    "    y_true_cens = computeTargetLabel(gt_cens, pred_cens)\n",
    "    y_score_cens= [key[2] for key in pred_cens]\n",
    "    \n",
    "    detected_gt_cens = computeDetectedGT(gt_cens, pred_cens)\n",
    "    miss_count = len(detected_gt_cens) - sum(detected_gt_cens)\n",
    "    y_true_cens = y_true_cens + [1] * miss_count\n",
    "    y_score_cens = y_score_cens + [0] * miss_count\n",
    "    \n",
    "    mAP_cens = np.append(mAP_cens, average_precision_score(y_true_cens, y_score_cens))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mAP_keys = np.array(mAP_keys)\n",
    "mAP_cens = np.array(mAP_cens)\n",
    "print('mAP for keypoints:', mAP_keys[~np.isnan(mAP_keys)].mean(), \" mAP for center points:\",mAP_cens[~np.isnan(mAP_cens)].mean())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
